{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python379jvsc74a57bd0b157589e2ede00d340fed454223ce98f3e66982c0431b5c5286cc0a4d3cc5a4f",
   "display_name": "Python 3.7.9 64-bit"
  },
  "metadata": {
   "interpreter": {
    "hash": "b157589e2ede00d340fed454223ce98f3e66982c0431b5c5286cc0a4d3cc5a4f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "import os\n",
    "from os.path import realpath\n",
    "import torch\n",
    "from skimage import io\n",
    "import numpy as np\n",
    "from util.config import cfg as test_cfg\n",
    "from data.test_dataset import TestDataset\n",
    "from util import util\n",
    "from models.networks import RainNet\n",
    "from models.normalize import RAIN\n",
    "import matplotlib.pyplot as plt\n",
    "from imageio import mimsave\n",
    "\n",
    "%matplotlib inline"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "def load_network(cfg):\n",
    "    net = RainNet(input_nc=cfg.input_nc, \n",
    "                output_nc=cfg.output_nc, \n",
    "                ngf=cfg.ngf, \n",
    "                norm_layer=RAIN, \n",
    "                use_dropout=not cfg.no_dropout)\n",
    "    \n",
    "    load_path = os.path.join(cfg.checkpoints_dir, cfg.name, 'net_G.pth')\n",
    "    if not os.path.exists(load_path):\n",
    "        raise FileExistsError, print('%s not exists. Please check the file'%(load_path))\n",
    "    print(f'loading the model from {load_path}')\n",
    "    state_dict = torch.load(load_path)\n",
    "    util.copy_state_dict(net.state_dict(), state_dict)\n",
    "    # net.load_state_dict(state_dict)\n",
    "    return net\n",
    "\n",
    "def save_img(path, img):\n",
    "    os.makedirs(os.path.split(path)[0], exist_ok=True)\n",
    "    io.imsave(path, img)"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "rainnet = load_network(test_cfg)"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "comp_path = ['examples/1.png', 'examples/2.png', 'examples/3.png']\n",
    "mask_path = ['examples/1-mask.png', 'examples/2-mask.png', 'examples/3-mask.png']\n",
    "real_path = ['examples/1-gt.png', 'examples/2-gt.png', 'examples/3-gt.png']\n",
    "# load the testing set\n",
    "testdata = TestDataset(foreground_paths=comp_path, mask_paths=mask_path, background_paths=real_path, load_size=256)"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "repeat_times = 0 # adjust the foreground image by several times\n",
    "for idx in tqdm(range(len(testdata))):\n",
    "    sample = testdata[idx]\n",
    "    # unsqueeze the data to shape of (1, channel, H, W)\n",
    "    comp = sample['comp'].unsqueeze(0).to(device)\n",
    "    mask = sample['mask'].unsqueeze(0).to(device) # if you want to adjust the background to be compatible with the foreground, then add the following command\n",
    "    # mask = 1 - mask\n",
    "    real = sample['real'].unsqueeze(0).to(device) # if the real_path is not given, then return composite image by sample['real']\n",
    "    img_path = sample['img_path']\n",
    "    pred = rainnet.processImage(comp, mask, real)\n",
    "    for i in range(repeat_times):\n",
    "        pred = rainnet.processImage(pred, mask, pred)\n",
    "        \n",
    "    # tensor2image\n",
    "    pred_rgb = util.tensor2im(pred[0:1])\n",
    "    comp_rgb = util.tensor2im(comp[:1])\n",
    "    mask_rgb = util.tensor2im(mask[:1])\n",
    "    real_rgb = util.tensor2im(real[:1])\n",
    "    print(img_path)\n",
    "    save_img(img_path.split('.')[0] + '-results.png', np.hstack([comp_rgb, mask_rgb, pred_rgb]))"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ]
}